# AOSC New Topic Initialize & Create (ANTIC)
# Interactive helper that opens a pull request for the current topic branch of aosc-os-abbs.

from typing import List, Dict, Optional, Iterable, Tuple
from getpass import getpass

import subprocess
import requests
import logging
import sys
import re
import os
import readline
import tempfile
import datetime
import json

# Path fragment that identifies the main AOSC abbs repository in remote URLs.
MAIN_REPO_PATH = 'github.com/AOSC-Dev/aosc-os-abbs'
# Branch that topic pull requests are opened against.
MAIN_TARGET_BRANCH = 'stable'
# (arch name, human-readable description) pairs.
MAIN_ARCHS = [('amd64', 'AMD64'), ('arm64', 'AArch64'),
              # the following are not included in ALL_ARCHS
              ('noarch', 'Architecture-independent'),
              ('optenv32', '32-bit Optional Environment')]
SEC_ARCHS = [('loongson3', 'Loongson 3'),
             ('ppc64el', 'PowerPC 64-bit (Little Endian)'),
             ('riscv64', 'RISC-V 64-bit')]
# Concrete build targets: the first two primary arches plus all secondary ones.
ALL_ARCHS = [i[0] for i in MAIN_ARCHS[:2] + SEC_ARCHS]

# These patterns are searched over a whole autobuild/defines file. MULTILINE
# makes ^ and $ apply per line, so an assignment is found wherever it sits in
# the file (a bare "ABHOST=noarch$" only matched at the very end of the file),
# and anchoring at the line start skips commented-out assignments.
FAIL_ARCH_PATTERN = re.compile(r'^[ \t]*FAIL_ARCH=(.+)', re.MULTILINE)
NOARCH_PATTERN = re.compile(r'^[ \t]*ABHOST=["\']?noarch["\']?[ \t]*$', re.MULTILINE)
MAIN_ARCH_FRAGMENT = '**Primary Architectures**'
SEC_ARCH_FRAGMENT = """**Secondary Architectures**

Architectural progress for "secondary," or experimental ports does not impede on merging of this topic."""
GITHUB_API_URL = 'https://api.github.com/graphql'
GITHUB_GQL_HEADERS = {'User-Agent': 'ANTIC/0.1',
                      'Accept': 'application/vnd.github.v4.idl'}
# %s: topic branch name.
PRECOND_CHECK_QUERY = """query {
  viewer { login }
  repository(name: "aosc-os-abbs", owner: "AOSC-Dev") {
    id
    viewerPermission
    labels(first: 100) { totalCount, nodes { id name } }
    ref(qualifiedName: "refs/heads/%s") {
      id, target { oid }
      associatedPullRequests(first: 5) { totalCount, nodes { number, state } }
    }
  }
}"""
# %s placeholders: repository node ID, base branch, head branch.
OPEN_PR_MUTATION = """mutation OpenPR($body: String!, $title: String!) {
  createPullRequest(
    input: {repositoryId: "%s", baseRefName: "refs/heads/%s", headRefName: "refs/heads/%s", title: $title, body: $body}
  ) { pullRequest { id number } }
}"""
# %s placeholders: labelable (PR) node ID, JSON array of label node IDs.
# GraphQL requires a selection set on object-typed fields such as this
# mutation's payload, so clientMutationId is requested even though unused.
LABEL_MUTATION = """mutation LabelMutation {
    addLabelsToLabelable(input: {labelableId: "%s", labelIds: %s}) { clientMutationId }
}"""
# %s placeholders: branch name, PR title, current UTC timestamp.
EDITOR_TEMPLATE = """<!-- ANTIC: Please enter the description of your topic (%s) below. -->
<!-- ANTIC: You can use any markdown syntax as you want. Save and quit the editor after you have finished. -->
<!-- ANTIC: Your title: %s; Current date in UTC: %s -->
"""
# str.format fields: desc, packages, sec_update, test_arches, stable_arches.
PULL_REQUEST_TEMPLATE = """<!-- Generated by ANTIC -->
Topic Description
-----------------

{desc}

Package(s) Affected
-------------------

```
{packages}
```

Security Update?
----------------

{sec_update}

Test Build(s) Done
------------------

{test_arches}

Update(s) Uploaded to Stable
----------------------------

{stable_arches}
"""


class ChangedPackage(object):
    """A package touched by the topic branch.

    Attributes:
        name: package name.
        arches: architecture names the package is expected to build for.
        sec_update: whether the change is a security update.
    """

    def __init__(self, name: str, arches: Iterable[str], sec_update: bool = False):
        self.name = name
        # Take a private copy: callers may hand over a shared module-level
        # list (e.g. ALL_ARCHS), which must never be mutated through here.
        self.arches: List[str] = list(arches)
        self.sec_update = sec_update

    def __repr__(self):
        return f'<ChangedPackage: name={self.name}, arches={self.arches}, sec_update={self.sec_update}>'


class RepoInfo:
    """Repository metadata fetched from GitHub that is needed to open a PR.

    Attributes:
        id: GraphQL node ID of the repository.
        labels: mapping of label name to label node ID.
    """

    def __init__(self, id: str, labels: Dict[str, str]):
        # `id` shadows the builtin, but the parameter name is part of the
        # public interface (keyword callers), so it is kept.
        self.labels = labels
        self.id = id


# Same semantics as acbs' fail_arch_regex (this function was copied from acbs).
def fail_arch_regex(expr: str) -> re.Pattern:
    """Compile a FAIL_ARCH expression into a prefix-anchored regex.

    A plain expression such as ``(amd64|arm64)`` becomes ``^(amd64|arm64)``.
    A leading ``!`` negates it: every ``(`` in the remainder is turned into a
    negative lookahead ``(?!``, and an unparenthesised remainder is wrapped in
    one, so ``!amd64`` and ``!(amd64)`` both become ``^(?!amd64)``. No
    trailing anchor is added.

    Raises:
        ValueError: if *expr* is shorter than three characters.
    """
    if len(expr) < 3:
        raise ValueError('Pattern too short.')
    if not expr.startswith('!'):
        return re.compile('^' + expr)
    rest = expr[1:]
    lookaheads = rest.replace('(', '(?!')
    if rest[0] == '(':
        # Already parenthesised: the first group becomes the lookahead.
        return re.compile('^' + lookaheads)
    # Bare name(s): wrap everything in a single negative lookahead.
    return re.compile('^(?!' + lookaheads + ')')


def check_preconditions(branch: str, token: str) -> Optional[RepoInfo]:
    """Verify on GitHub that *branch* is ready to become a pull request.

    Checks that the branch exists on AOSC-Dev/aosc-os-abbs, that no pull
    request is associated with it yet, and that its remote tip equals the
    local ``HEAD``. Lacking ADMIN/MAINTAIN/WRITE permission only logs a
    warning.

    Args:
        branch: topic branch name (without ``refs/heads/``).
        token: GitHub personal access token.

    Returns:
        The repository node ID plus a label-name -> label-ID mapping, or
        ``None`` if a precondition fails (the reason is logged).

    Raises:
        requests.HTTPError: if the API answers with an HTTP error status.
        subprocess.CalledProcessError: if ``git rev-parse HEAD`` fails.
    """
    query = {'query': PRECOND_CHECK_QUERY % branch}
    headers = GITHUB_GQL_HEADERS.copy()
    headers['Authorization'] = f'Bearer {token}'
    response = requests.post(GITHUB_API_URL, json=query, headers=headers)
    # Check the status before decoding: an error response (e.g. bad
    # credentials) need not carry a GraphQL payload at all.
    response.raise_for_status()
    payload = response.json()
    errors = payload.get('errors')
    if errors:
        # GraphQL errors are a list of objects under "errors"; there is no
        # top-level "message" key in that case.
        for err in errors:
            logging.error(err.get('message', str(err)))
        return None
    data = payload['data']
    repository = data['repository']
    is_aosc_member = repository['viewerPermission'] in [
        'ADMIN', 'MAINTAIN', 'WRITE']
    if not is_aosc_member:
        name = data['viewer']['login']
        logging.warning(
            f'You ({name}) are not a member of the AOSC. This tool may be useless.')
    commit_ref = repository['ref']
    if not commit_ref:
        logging.error(
            f'Branch {branch} does not exist in AOSC-Dev/aosc-os-abbs. Please create it first.')
        return None
    if commit_ref['associatedPullRequests']['totalCount'] > 0:
        logging.error(
            f'Branch {branch} already belongs to a pull request. Please do not re-use topic names.')
        return None
    tip_commit = commit_ref['target']['oid']
    ws_tip_commit = subprocess.check_output(
        ['git', 'rev-parse', 'HEAD']).decode('utf-8').strip()
    if tip_commit != ws_tip_commit:
        logging.error(
            'Your local branch is different than the one on the GitHub. Please push or pull the changes.')
        return None
    labels_dict = {label['name']: label['id']
                   for label in repository['labels']['nodes']}
    return RepoInfo(repository['id'], labels_dict)


def get_upstreamable_remote() -> str:
    """Return the name of the git remote that points at the main abbs repo.

    Recognises both ``https://github.com/AOSC-Dev/...`` style URLs and the
    scp-like SSH form ``git@github.com:AOSC-Dev/...``; the comparison is
    case-insensitive.

    Raises:
        RuntimeError: if no configured remote matches MAIN_REPO_PATH.
        subprocess.CalledProcessError: if ``git remote -v`` fails.
    """
    wanted = MAIN_REPO_PATH.lower()
    remotes = subprocess.check_output(['git', 'remote', '-v']).decode('utf-8')
    for line in remotes.splitlines():
        # Each line looks like "<name>\t<url> (fetch|push)".
        name, _, url = line.partition('\t')
        # Normalise scp-like "host:path" syntax to "host/path".
        normalized = url.lower().replace('github.com:', 'github.com/')
        if wanted in normalized:
            return name
    raise RuntimeError(
        'Could not determine which remote is upstreamable to AOSC abbs tree.')


def scan_changed_package(name: str, path: str) -> ChangedPackage:
    """Work out which architectures package *name* (located at *path*) targets.

    - a path containing ``extra-optenv32`` -> ``['optenv32']``
    - ``ABHOST=noarch`` in ``autobuild/defines`` -> ``['noarch']``
    - otherwise every entry of ALL_ARCHS not excluded by ``FAIL_ARCH``.

    Raises:
        OSError: if ``autobuild/defines`` cannot be read.
        ValueError, re.error: if the file is not UTF-8 or the FAIL_ARCH
            expression is malformed.
    """
    if 'extra-optenv32' in path:
        return ChangedPackage(name, ['optenv32'])
    defines_path = os.path.join(path, 'autobuild', 'defines')
    with open(defines_path, 'rt', encoding='utf-8') as f:
        content = f.read()
    if NOARCH_PATTERN.search(content):
        return ChangedPackage(name, ['noarch'])
    fail_arch = FAIL_ARCH_PATTERN.search(content)
    if not fail_arch:
        # Copy so the result never aliases the module-level list.
        return ChangedPackage(name, list(ALL_ARCHS))
    # The value is a shell word: drop surrounding whitespace and either kind
    # of quote before interpreting it.
    expr = fail_arch.group(1).strip().strip('"\'')
    fail_arch_pat = fail_arch_regex(expr)
    supported = [arch for arch in ALL_ARCHS if not fail_arch_pat.match(arch)]
    return ChangedPackage(name, supported)


def get_changed_packages(remote: str, topdir: str) -> Dict[str, ChangedPackage]:
    """Collect the packages touched by commits in ``<remote>/stable..HEAD``.

    A package is any path of the form ``<category>/<package>/<file>``. The
    first occurrence of each package name wins. Packages whose definition
    cannot be scanned (e.g. deleted ones) are still listed, with no arches,
    and a warning is logged.

    Args:
        remote: name of the upstream git remote.
        topdir: top-level directory of the working tree.

    Returns:
        Mapping of package name to ChangedPackage, in ``git log`` order.
    """
    upstream_branch = f'{remote}/{MAIN_TARGET_BRANCH}'
    command = ['git', 'log', '--name-only',
               '--format=', f'{upstream_branch}..HEAD']
    changes = subprocess.check_output(command).decode('utf-8')
    packages: Dict[str, ChangedPackage] = {}
    for line in changes.splitlines():
        parts = line.split('/', 2)
        # Only "<category>/<package>/<file>" paths describe a package; hidden
        # top-level directories (such as ".github") are not categories.
        if len(parts) < 3 or parts[0].startswith('.'):
            continue
        category, pkg_name = parts[0], parts[1]
        if pkg_name in packages:
            continue
        try:
            packages[pkg_name] = scan_changed_package(
                pkg_name, os.path.join(topdir, category, pkg_name))
        except Exception as exc:
            # Deliberately best-effort, but no longer silent.
            logging.warning(
                f'Could not determine architectures of {pkg_name}: {exc}')
            packages[pkg_name] = ChangedPackage(pkg_name, [])
    return packages


def simulate_merging(remote: str) -> bool:
    """Return True if HEAD merges into ``<remote>/MAIN_TARGET_BRANCH`` cleanly.

    Runs the legacy three-argument form of ``git merge-tree`` against the
    merge base and scans its output for conflict markers.

    Raises:
        subprocess.CalledProcessError: if a git command fails.
    """
    upstream_branch = f'{remote}/{MAIN_TARGET_BRANCH}'
    base_commit = subprocess.check_output(
        ['git', 'merge-base', upstream_branch, 'HEAD']).decode('utf-8').strip()
    merge_diff = subprocess.check_output(
        ['git', 'merge-tree', base_commit, upstream_branch, 'HEAD'])
    # Match a full 7-character conflict marker at the start of a (diff) line.
    # A bare b'<<<' also hits bash here-strings ("cmd <<< word"), which are
    # common in build scripts and were reported as conflicts.
    conflict = re.search(rb'^[+ ]?<{7}(?: |$)', merge_diff, re.MULTILINE)
    return conflict is None


def workspace_clean() -> bool:
    """Return True if ``git status --porcelain`` reports no entries at all.

    Raises:
        subprocess.CalledProcessError: if ``git status`` fails.
    """
    output = subprocess.check_output(['git', 'status', '--porcelain'])
    # Every porcelain entry is an "XY path" line, so any output means dirty;
    # this replaces the former magic length threshold.
    return not output.strip()


def ask_title() -> str:
    """Prompt until the user enters a PR title of at least five characters.

    Blank input re-prompts silently; a non-blank but too short title logs an
    error before re-prompting. The returned title is whitespace-stripped.
    """
    answer = ''
    while len(answer) < 5:
        answer = input('Please enter a title for the PR: ').strip()
        if answer and len(answer) < 5:
            logging.error('Come on... Enter a better title!')
    return answer


def get_default_editor() -> str:
    """Return $EDITOR, else $VISUAL, else ``nano`` (empty values are skipped)."""
    for variable in ('EDITOR', 'VISUAL'):
        candidate = os.environ.get(variable)
        if candidate:
            return candidate
    return 'nano'


def generate_arch_fragment(arches: Iterable[str]) -> str:
    """Render Markdown checklists for the given architecture names.

    Primary (MAIN_ARCHS) and secondary (SEC_ARCHS) architectures get separate
    sections, each introduced by its heading fragment. A section without any
    matching arch renders as an empty string; the two sections are always
    joined by a blank line.
    """
    sections = []
    for heading, table in ((MAIN_ARCH_FRAGMENT, MAIN_ARCHS),
                           (SEC_ARCH_FRAGMENT, SEC_ARCHS)):
        items = ''.join(f'- [ ] {desc} `{name}`\n'
                        for name, desc in table if name in arches)
        sections.append(f'{heading}\n\n{items}' if items else '')
    return '\n\n'.join(sections)


def trim_description(data: str) -> str:
    """Remove ANTIC instruction comments from the edited description.

    Lines starting with ``<!-- ANTIC:`` are dropped. Every other line keeps
    its leading indentation, which is significant in Markdown (nested lists,
    indented code blocks), loses trailing whitespace, and ends with ``\\n``.
    """
    return ''.join(
        line.rstrip() + '\n'
        for line in data.splitlines()
        if not line.startswith('<!-- ANTIC:')
    )


def translate_labels(pool: Dict[str, str], labels: Iterable[str]) -> List[str]:
    """Map label names to IDs via *pool*, in input order.

    Names that are missing from *pool* (or map to a falsy ID) are skipped
    with a warning.
    """
    resolved: List[str] = []
    for label_name in labels:
        label_id = pool.get(label_name)
        if label_id:
            resolved.append(label_id)
        else:
            logging.warning(f'Label "{label_name}" is not found.')
    return resolved


def create_pull_request(token: str, repo_id: str, branch: str, qvars: Dict[str, str]) -> Tuple[str, int]:
    """Open a pull request from *branch* into MAIN_TARGET_BRANCH.

    Args:
        token: GitHub personal access token.
        repo_id: GraphQL node ID of the repository.
        branch: head branch of the pull request.
        qvars: GraphQL variables; must provide ``title`` and ``body``.

    Returns:
        ``(node_id, number)`` of the newly created pull request.

    Raises:
        requests.HTTPError: if the API answers with an HTTP error status.
        RuntimeError: if the API reports GraphQL errors.
    """
    pr_mutation = OPEN_PR_MUTATION % (repo_id, MAIN_TARGET_BRANCH, branch)
    query = {'query': pr_mutation, 'variables': qvars}
    headers = GITHUB_GQL_HEADERS.copy()
    headers['Authorization'] = f'Bearer {token}'
    response = requests.post(GITHUB_API_URL, json=query, headers=headers)
    response.raise_for_status()
    payload = response.json()
    errors = payload.get('errors')
    if errors:
        # GraphQL errors are a list of objects under "errors"; there is no
        # top-level "message" key to read.
        for err in errors:
            logging.error(err.get('message', str(err)))
        raise RuntimeError('Failed to create pull request')
    pr_info = payload['data']['createPullRequest']['pullRequest']
    return (pr_info['id'], pr_info['number'])


def attach_labels(token: str, pr_id: str, labels: List[str]) -> bool:
    """Attach labels (given by node ID) to the labelable object *pr_id*.

    GraphQL-level failures are logged and reported through the return value
    rather than raised.

    Returns:
        True if the labels were attached or there was nothing to attach,
        False if the API reported errors.

    Raises:
        requests.HTTPError: if the API answers with an HTTP error status.
    """
    if not labels:
        return True
    lbl_mutation = LABEL_MUTATION % (pr_id, json.dumps(labels))
    query = {'query': lbl_mutation}
    headers = GITHUB_GQL_HEADERS.copy()
    headers['Authorization'] = f'Bearer {token}'
    response = requests.post(GITHUB_API_URL, json=query, headers=headers)
    response.raise_for_status()
    # GraphQL errors come back with HTTP 200, so they must be checked here.
    errors = response.json().get('errors')
    if errors:
        for err in errors:
            logging.error(f'Failed to attach labels: {err.get("message", err)}')
        return False
    return True


def _edit_description(editor: str, branch: str, title: str) -> Tuple[str, str]:
    """Let the user write the PR description in *editor*.

    Returns:
        ``(description, path)``: the description without ANTIC comment lines,
        and the path of the temporary file holding it (left on disk for the
        caller to remove).
    """
    import shlex

    fd, edit_file = tempfile.mkstemp(suffix='.md')
    utc_time = datetime.datetime.now(datetime.timezone.utc).isoformat()
    with os.fdopen(fd, 'w', encoding='utf-8') as f:
        f.write(EDITOR_TEMPLATE % (branch, title, utc_time))
    # $EDITOR may carry arguments (e.g. "code --wait"), so split it like a
    # shell would instead of treating the whole string as a program name.
    command = shlex.split(editor) or [editor]
    subprocess.check_call(command + [edit_file])
    # Re-open by path: editors that save by writing a new file and renaming
    # it over the old one would be invisible through the original descriptor.
    with open(edit_file, 'r', encoding='utf-8') as f:
        description = trim_description(f.read())
    return description, edit_file


def _guess_labels(title: str, is_survey: bool, arches: Iterable[str]) -> List[str]:
    """Guess label names from the survey flag, touched arches and the title."""
    labels = []
    if is_survey:
        labels.append('survey')
    if 'optenv32' in arches:
        labels.append('optenv')
    lowered = title.lower()
    if 'new,' in lowered:
        labels.append('new-package')
    if 'update' in lowered or 'upgrade' in lowered:
        labels.append('upgrade')
    if 'fix ' in lowered:
        labels.append('bug')
    return labels


def main():
    """Interactively open a pull request for the current topic branch.

    Checks the local branch and workspace, fetches the upstream remote,
    simulates merging into MAIN_TARGET_BRANCH, verifies the preconditions on
    GitHub, aggregates the changed packages, asks for a title and a
    description, then creates the pull request and attaches guessed labels.
    Calls ``sys.exit`` with status 1 or 2 when a check fails or the user
    aborts.
    """
    editor = get_default_editor()
    topdir = subprocess.check_output(
        ['git', 'rev-parse', '--show-toplevel']).decode('utf-8').strip()
    this_branch = subprocess.check_output(
        ['git', 'branch', '--show-current']).decode('utf-8').strip()
    if not this_branch:
        # `git branch --show-current` prints nothing on a detached HEAD.
        logging.error(
            'You are not on any branch (detached HEAD). Please check out your topic branch.')
        sys.exit(2)
    if len(this_branch) < 4:
        logging.error(
            'Branch name is too short. Please use a more descriptive branch name.')
        sys.exit(2)
    if this_branch == MAIN_TARGET_BRANCH:
        logging.error(
            f'You are on {MAIN_TARGET_BRANCH} itself. Please open the pull request from a topic branch.')
        sys.exit(1)
    # NOTE(review): binding an empty string looks like a no-op; importing
    # readline is what gives input() line-editing support.
    readline.parse_and_bind('')
    if not workspace_clean():
        logging.warning(
            'Your workspace contains uncommitted or untracked changes.')
        if input('Do you still wish to continue? [y/N] ').strip().lower() != 'y':
            logging.error('User chose to abort.')
            sys.exit(1)

    target_remote = get_upstreamable_remote()
    logging.info(f'Main repository remote: {target_remote}')
    print('Fetching new changes ...')
    subprocess.check_call(['git', 'fetch', target_remote])
    print('Checking merge status ...')
    if not simulate_merging(target_remote):
        logging.error(
            f'This branch ({this_branch}) is not safe to merge into stable.')
        logging.error(
            f'Please rebase this branch against {target_remote}/{MAIN_TARGET_BRANCH}!')
        sys.exit(2)
    print('The branch is safe to merge.')
    if this_branch == 'retro':
        logging.error(
            'This tool does not handle Mainline/Retro synchronization. Please do it manually.')
        sys.exit(1)
    token = getpass('Please enter your GitHub access token (PAT): ')
    print('Checking pre-conditions ...')
    repo_data = check_preconditions(this_branch, token)
    if not repo_data:
        logging.error('Preconditions not met. Can not continue.')
        sys.exit(1)
    print('Preconditions seem okay.')
    print('Aggregating changes ...')
    changed_packages = get_changed_packages(target_remote, topdir)
    mentioned_arches = set()
    for p in changed_packages.values():
        mentioned_arches.update(p.arches)
    is_survey = False
    if len(changed_packages) > 15:
        logging.warning('You are updating a lot of packages!')
        is_survey = input(
            'Is this a survey topic? [Y/n] ').strip().lower() != 'n'
    print('Good. Now you can open a pull request.')
    arch_frag = generate_arch_fragment(mentioned_arches)
    packages_frag = '\n'.join(changed_packages)
    print('Before that, we need to ask you a few questions ...')
    title = ask_title()
    # TODO: security update question
    # TODO: draft PR question
    description, edit_file = _edit_description(editor, this_branch, title)
    final_body = PULL_REQUEST_TEMPLATE.format(
        desc=description, packages=packages_frag,
        test_arches=arch_frag, stable_arches=arch_frag,
        sec_update='No'
    )
    final_labels = _guess_labels(title, is_survey, mentioned_arches)
    # prepare for the submission
    label_ids = translate_labels(repo_data.labels, final_labels)
    gql_vars = {'title': title, 'body': final_body}
    try:
        pr_id, pr_num = create_pull_request(
            token, repo_data.id, this_branch, gql_vars)
    except Exception:
        # Keep the description on disk so the user does not lose it.
        logging.error(
            f'Could not create the pull request. Your description is kept at {edit_file}.')
        raise
    os.unlink(edit_file)
    attach_labels(token, pr_id, label_ids)
    print(f'Done! Pull request #{pr_num} created! Please see https://{MAIN_REPO_PATH}/pull/{pr_num} .')


# Run the interactive workflow only when executed as a script, not on import.
if __name__ == '__main__':
    main()
